[weather] Add Fgn model #1660
Conversation
Greptile SummaryThis PR adds a new
Important Files Changed
Reviews (1): Last reviewed commit: "Merge branch 'main' into fgn" | Re-trigger Greptile |
| for k in range(ar_steps): | ||
| members = [] | ||
| for n in range(num_samples): | ||
| hist_n = per_member_hist[:, n] | ||
| pred_n = self._step_ensemble( | ||
| history=hist_n, | ||
| background=background, | ||
| invariants=invariants, | ||
| num_samples=1, | ||
| )[:, 0] | ||
| members.append(pred_n) | ||
| preds = torch.stack(members, dim=1) # (B, N, C, H, W) |
There was a problem hiding this comment.
Redundant wrapping of single model forward pass through
_step_ensemble. The outer for n in range(num_samples) loop calls _step_ensemble(..., num_samples=1) for each member, which internally runs its own for _ in range(1) loop. This double-loops and makes the code harder to follow; calling the model directly (as the validation loop in _run_validation_metrics does) would be cleaner and avoids the vestigial num_samples=1 sentinel.
There was a problem hiding this comment.
Removed in the latest push. _step_ensemble was indeed dead code — _loss and _validation_loss both inline the ensemble loop directly. The double-loop pattern was from an earlier version.
Remove cluster-specific slurm scripts (local paths), untrack FGN.md (dev notes), add .gitignore, and fix README references. Signed-off-by: Kashif Rasul <kashif.rasul@gmail.com>
…e indirection - Fix inference _rollout: torch.cat([history[:, 1:], next_frame.unsqueeze(1)]) so history window slides correctly for any history_frames value, not just 2 - Remove unimplemented amp config field from TrainingConfig and default.yaml - Inline model call in _loss AR loop instead of routing through _step_ensemble with num_samples=1 (each member needs its own history, so the single-call collapse doesn't apply; direct call is cleaner) Signed-off-by: Kashif Rasul <kashif.rasul@gmail.com>
- Run ruff format + fix across all fgn/ Python files - Remove unused imports (Sequence, Callable, ShardTensor, math, torch) - Replace assert with if/raise (S101), fix import order (I001), simplify loops to list-comprehension/extend (PERF401/102) - Add noqa: E402 on intentional post-path-insert imports in stage4 - Upgrade FGNUNet docstring to MOD-003 (r-string, NumPy sections, Parameters/Forward/Outputs with LaTeX shapes, Examples) - Add CHANGELOG.md entry under [2.1.0a0] Signed-off-by: Kashif Rasul <kashif.rasul@gmail.com>
…nfig - utils/metrics.py: add energy_score_per_lead — fair energy score (multivariate CRPS) over the variable axis with spatial subsampling; new in earth2studio 0.13.0, captures cross-channel calibration - utils/trainer.py: wire energy_score_per_lead into validation hook, save to metrics.npz and plot energy_score_vs_lead.png - config/fgn.yaml: base Hydra config required by train.py (@hydra.main config_name="fgn") with model defaults and dataset skeleton; was missing, causing Hydra to error without all overrides - config/fgn_arco.yaml: practical single-GPU ARCO ERA5 training config (2018–2022 train / 2023 val, hidden_channels=64, 5000 steps, full loss weights) for runs beyond the smoke-test default Signed-off-by: Kashif Rasul <kashif.rasul@gmail.com>
- datasets/__init__.py: auto-discovery registry (mirrors stormcast) that populates dataset_classes dict by scanning all FGNDataset subclasses; fixes ImportError since the regular HF `datasets` package beat the namespace package without __init__.py - datasets/dataset.py: FGNDataset ABC (state_channels, background_channels, image_shape, get_invariants, output_only_channels) + worker_init; mirrors stormcast/datasets/dataset.py convention - utils/loss.py: fair_crps (paper eq. 4), ensemble_mean_mse, build_channel_weights (§2.2.3 GraphCast scheme with z halved), build_area_weights (cos-lat normalised to unit mean) All three files existed locally before the branch cleanup but were never committed; this adds them properly. Signed-off-by: Kashif Rasul <kashif.rasul@gmail.com>
- Use _make_train_iter() to route through sharded_data_iter when domain parallelism is active (mirrors StormCast trainer pattern); plain DDP path gets an infinite-restart iterator instead of the old bare iter() - Wrap both model forward sites in torch.autocast(bfloat16) and call .float() on preds to keep loss computation in fp32; halves activation memory at full 721x1440 resolution on H100 80GB - train_fgn.sh: batch_size=1, domain_parallel_size=1 (DDP), run_id Hydra string quoting fix, PYTORCH_CUDA_ALLOC_CONF=expandable_segments Signed-off-by: Kashif Rasul <kashif.rasul@gmail.com>
With domain_parallel_size=1 and 2 GPUs, data_parallel_size=2 so local_batch = batch_size // 2; batch_size=1 → local_batch=0 causing BatchSampler ValueError. Use batch_size=2 (global) = 1 per GPU. Signed-off-by: Kashif Rasul <kashif.rasul@gmail.com>
Mark tp06, multi-rank sanity, AR stage scheduler, bad-seed detector as done. Add status for currently running 5000-step job 99807. Signed-off-by: Kashif Rasul <kashif.rasul@gmail.com>
Signed-off-by: Kashif Rasul <kashif.rasul@gmail.com>
Replaces the stale MVP-scaffold README with a full recipe README modelled on stormcast/README.md: problem overview, dataset (ARCO), getting started, configuration table, training (single-GPU / torchrun / SLURM), AR fine-tuning schedule, inference, custom dataset interface, memory guidance, and references. Signed-off-by: Kashif Rasul <kashif.rasul@gmail.com>
- Link to developers.google.com/weathernext/guides/models and model-specs-vmg from the intro and References section - Clarify production deployment: 64 members (4 seeds × 16 each) - Note u100m/v100m omission: ERA5/ARCO lacks 100m winds Signed-off-by: Kashif Rasul <kashif.rasul@gmail.com>
Signed-off-by: Kashif Rasul <kashif.rasul@gmail.com>
…elper Signed-off-by: Kashif Rasul <kashif.rasul@gmail.com>
Signed-off-by: Kashif Rasul <kashif.rasul@gmail.com>
Signed-off-by: Kashif Rasul <kashif.rasul@gmail.com>
Signed-off-by: Kashif Rasul <kashif.rasul@gmail.com>
Replace _ar_rollout (returns full (B,K,M,C,H,W) = ~51 GB) with _ar_rollout_steps generator (yields (B,M,C,H,W) per step, ~2.5 GB). All metrics computed per-step via unsqueeze(1) trick. Mirrors earth2studio GenCast/GraphCast yield+del pattern. Signed-off-by: Kashif Rasul <kashif.rasul@gmail.com>
- lr: 3e-4 → 8e-5 (paper stage 2-4 value) - weight_decay: 1e-4 → 0.1 (paper value) - Add linear warmup (1000 steps) + cosine decay LR schedule - Save/restore scheduler state in checkpoints - Log lr in progress line Signed-off-by: Kashif Rasul <kashif.rasul@gmail.com>
Paper stage 4: 8e-5 (1AR) → 8e-6 (2AR) → 8e-7 (3-8AR), not the incorrect 8e-5/8e-5/8e-5/8e-6/8e-6/8e-7/8e-7/8e-7 we had. Also thread lr_warmup_steps (800/400/100) through build_stage_cfg. DEV_STAGES updated to mirror paper's LR ratios. Signed-off-by: Kashif Rasul <kashif.rasul@gmail.com>
_validation_loss() now returns (scalar, per_channel_mse[C]) alongside
the aggregate loss. Per-channel values are logged at each validation
step, mirroring StormCast's log_value(f'loss/valid/{ch}') convention.
Cheap: uses a single deterministic forward pass (latent=0), no ensemble.
Scheduler state saved/restored via save_checkpoint/load_checkpoint.
Signed-off-by: Kashif Rasul <kashif.rasul@gmail.com>
…alues - training.torch_compile flag: compiles model before FSDP wrapping, skipped when ShardTensor active (matches CorrDiff/GraphCast pattern) - Unwrap _orig_mod for save_checkpoint (OptimizedModule has no __len__) - default.yaml: lr 3e-4→8e-5, weight_decay 1e-4→0.1 (paper Table A.2), add lr_warmup_steps/lr_min/torch_compile fields Signed-off-by: Kashif Rasul <kashif.rasul@gmail.com>
Was overriding lr=3e-4 and missing weight_decay/warmup fields. Now: lr=8e-5, weight_decay=0.1, lr_warmup_steps=1000 (Stage 3 values). Signed-off-by: Kashif Rasul <kashif.rasul@gmail.com>
Matches DLWP/FCN-AFNO/diagnostic pattern from physicsnemo.utils.logging. - LaunchLogger.initialize() called at trainer startup (no-op by default) - train loop uses LaunchLogger context for loss/lr minibatch logging - val loop logs val_loss + per-channel MSE via LaunchLogger - use_wandb/use_mlflow/wandb_project config flags (all off by default) - Fix _run_validation_metrics plot_power_spectra call (new signature) Enable W&B: training.use_wandb=true training.wandb_project=my-project Signed-off-by: Kashif Rasul <kashif.rasul@gmail.com>
- add training.amp bool (default true) mirroring graphcast convention; autocast now reads self.amp instead of hardcoded is_available() - remove _step_ensemble which was dead code (never called) Signed-off-by: Kashif Rasul <kashif.rasul@gmail.com>
- latent_dim default 16→32 (paper §2.3: z∈N(0,1)^32) - eval CRPS switches to biased estimator (paper §4.1): deep-ensemble violates the independence assumption of the fair variant - applies to both eval.py (e2s_crps fair=False) and trainer _run_validation_metrics (kcrps biased=True) Signed-off-by: Kashif Rasul <kashif.rasul@gmail.com>
Signed-off-by: Kashif Rasul <kashif.rasul@gmail.com>
Signed-off-by: Kashif Rasul <kashif.rasul@gmail.com>
…notes Signed-off-by: Kashif Rasul <kashif.rasul@gmail.com>
Signed-off-by: Kashif Rasul <kashif.rasul@gmail.com>
Signed-off-by: Kashif Rasul <kashif.rasul@gmail.com>
Signed-off-by: Kashif Rasul <kashif.rasul@gmail.com>
Paper §4.1 / Figure 2g-h evaluates REV at the 99.99th percentile (z ≈ 3.72); previous default stopped at p99. Signed-off-by: Kashif Rasul <kashif.rasul@gmail.com>
Closes the Figure 4 gap from the paper gap analysis: - tc_eval.py: AR rollout + earth2studio TCTrackerWuDuan per IC, IBTrACS ground-truth pairing, position error + track REV metrics - metrics.py: plot_tc_position_error / plot_tc_track_rev (Fig 4a-b) - arco.py: expose init_time in __getitem__ return dict Requires: pip install cucim-cu12 Signed-off-by: Kashif Rasul <kashif.rasul@gmail.com>
Richardson 2000: v_clim = max(0, μ-r), v_perf = μ*(1-r). Prior code had them swapped, producing wrong REV denominator. Also add frac_active >= 0.5 guard on position error per §4.3: ensemble-mean position only counted when ≥50% of members still forecast the cyclone. Signed-off-by: Kashif Rasul <kashif.rasul@gmail.com>
After each save, rank 0 deletes all but the 2 most recent .mdlus and checkpoint.*.pt files to prevent disk exhaustion. Signed-off-by: Kashif Rasul <kashif.rasul@gmail.com>
…=12) Switch from a tiny 449K-param convolutional U-Net to physicsnemo's DiT. Patchifying 721x1440 into 181x360=65k tokens before attention gives 16x memory reduction, enabling batch>1 and larger model capacity (~33M params). z ~ N(0,I)^32 conditions all transformer layers via AdaLN-Zero (passed as condition= to DiT), matching paper §2.3's global conditional layer-norm exactly. Signed-off-by: Kashif Rasul <kashif.rasul@gmail.com>
- Remove unused torch.nn import from nn.py (F401) - Replace append loop with list comprehension in eval.py (PERF401) - Add patch_size/hidden_size/depth/num_heads to ModelConfig (extra=forbid) - Remove UNet-only fields (hidden_channels, group_norm_groups) Signed-off-by: Kashif Rasul <kashif.rasul@gmail.com>
PatchEmbed2D allocates pos_embed with floor(H/ps) but pads the input at runtime, producing ceil(H/ps) tokens — causing a size mismatch on any non-divisible grid (ERA5 721 with ps=4: 180 slots vs 181 tokens). Pre-pad the input to the nearest patch multiple before DiT so its internal padding path never fires, matching StormCast's practice of always passing divisible resolutions. Crop the output back afterward. Signed-off-by: Kashif Rasul <kashif.rasul@gmail.com>
PhysicsNeMo Pull Request
Description
Adds the FGN (Functional Generative Networks) weather model example from arXiv:2506.10772.
FGN is a latent-conditioned UNet trained with fair-CRPS to generate calibrated ensemble forecasts. This PR includes:
examples/weather/fgn/— training, evaluation, and inference scriptsChecklist
Dependencies
No new dependencies beyond what is already in the PhysicsNeMo environment.
Review Process
All PRs are reviewed by the PhysicsNeMo team before merging.
Depending on which files are changed, GitHub may automatically assign a maintainer for review.
We are also testing AI-based code review tools (e.g., Greptile), which may add automated comments with a confidence score.
This score reflects the AI's assessment of merge readiness and is not a qualitative judgment of your work, nor is
it an indication that the PR will be accepted / rejected.